{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use your Likelihood Function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you have trained your own LAN (likelihood approximation network), or have any object with a `predict_on_batch()` method that returns the log-likelihood of a dataset, you can pass it as an argument to the `HDDMnn`, `HDDMnnRegressor` or `HDDMnnStimCoding` classes.\n",
    "\n",
    "This document walks through a simple, complete example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For our purposes, we will call the new model `custom` and assume that we have already trained a LAN for it.\n",
    "\n",
    "We need two components to use our `custom` LAN:\n",
    "\n",
    "1. A config dictionary for the `custom` model.\n",
    "2. A pretrained LAN with a `predict_on_batch` method."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Construct config dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hddm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The config dictionaries for all models included in `HDDM` are stored in the `hddm.model_config.model_config` dictionary. Learn more about this object in the **LAN tutorial concerning new models**.\n",
    "\n",
    "We now generate a custom `model_config` for our new model. Let's start with the minimal dictionary that you need to define."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_model_config = {}\n",
    "\n",
    "# parameter names associated with your model\n",
    "my_model_config[\"params\"] = [\"v\", \"a\", \"z\", \"t\"]\n",
    "\n",
    "# the parameter boundaries you used for training your LAN\n",
    "my_model_config[\"param_bounds\"] = [[-3.0, 0.3, 0.1, 1e-3], [3.0, 2.5, 0.9, 2.0]]\n",
    "\n",
    "# add a boundary function\n",
    "my_model_config[\"boundary\"] = hddm.simulators.boundary_functions.constant\n",
    "\n",
    "# suggestion for which parameters to include\n",
    "# via the include statement of an HDDM model\n",
    "# (usually you want all of the parameters from above)\n",
    "my_model_config[\"hddm_include\"] = [\"z\"]\n",
    "\n",
    "# choice labels (what your simulator spits out)\n",
    "my_model_config[\"choices\"] = [-1, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we can add optional settings, which are useful, for example, for improving the performance of the sampler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specifies which parameters the sampler should transform\n",
    "# (one flag per entry in my_model_config[\"params\"])\n",
    "my_model_config[\"params_trans\"] = [0, 0, 1, 0]\n",
    "\n",
    "# sets the slice widths the sampler uses for each parameter\n",
    "my_model_config[\"slice_widths\"] = {\n",
    "    \"v\": 1.5,\n",
    "    \"v_std\": 1,\n",
    "    \"a\": 1,\n",
    "    \"a_std\": 1,\n",
    "    \"z\": 0.1,\n",
    "    \"z_trans\": 0.2,\n",
    "    \"t\": 0.01,\n",
    "    \"t_std\": 0.15,\n",
    "}\n",
    "\n",
    "# set sampler starting points manually for each parameter\n",
    "my_model_config[\"params_default\"] = [0.0, 1.0, 0.5, 1e-3]\n",
    "\n",
    "# set a (reasonable) upper limit of group level standard deviations,\n",
    "# this can help with sampler stability\n",
    "my_model_config[\"params_std_upper\"] = [1.5, 1.0, None, 1.0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the example complete, here is a code snippet that loads a `torch` model. For simplicity, we use the `ddm` network here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "custom_network = hddm.torch.mlp_inference_class.load_torch_mlp(model=\"ddm\")\n",
    "# any object with a valid `predict_on_batch` method works here"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**NOTE:**\n",
    "\n",
    "Above, for simplicity, we load a network that is already available in HDDM. In practice, you would load your own network here instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hddm.simulators.hddm_dataset_generators import simulator_h_c\n",
    "\n",
    "# Simulate some data:\n",
    "model = \"ddm\"\n",
    "n_subjects = 1\n",
    "n_samples_by_subject = 500\n",
    "\n",
    "data, full_parameter_dict = simulator_h_c(\n",
    "    n_subjects=n_subjects,\n",
    "    n_samples_by_subject=n_samples_by_subject,\n",
    "    model=model,\n",
    "    p_outlier=0.00,\n",
    "    conditions=None,\n",
    "    depends_on=None,\n",
    "    regression_models=None,\n",
    "    regression_covariates=None,\n",
    "    group_only_regressors=False,\n",
    "    group_only=None,\n",
    "    fixed_at_default=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize HDDM Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you are ready to initialize an HDDM model with your `custom` LAN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the HDDM model\n",
    "hddmnn_model = hddm.HDDMnn(\n",
    "    data=data,\n",
    "    informative=False,\n",
    "    # Note: this include statement is an example; you may pick any\n",
    "    # other subset of the parameters of your model here\n",
    "    include=my_model_config[\"hddm_include\"],\n",
    "    # model = 'custom',\n",
    "    model_config=my_model_config,\n",
    "    network=custom_network,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You are now ready to get samples from your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " [-----------------100%-----------------] 1000 of 1000 complete in 36.9 sec"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<pymc.MCMC.MCMC at 0x149f01150>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hddmnn_model.sample(1000, burn=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**WARNING**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not all functionality of the HDDM package will work seamlessly with such custom likelihoods.\n",
    "You will be able to generate some, but not all, plots.\n",
    "\n",
    "The utility lies in using HDDM as a vehicle to sample from user-defined approximate likelihoods.\n",
    "Most of the package's utility functions are specific to models that have been fully incorporated into the package.\n",
    "\n",
    "A tutorial concerning full integration into HDDM will be made available in the near future."
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "560f708a8e3e7d8a6040d2821b64a1d977dd58637277f7fff51c695b3b27a4a1"
  },
  "kernelspec": {
   "display_name": "hddmnn_tutorial",
   "language": "python",
   "name": "hddmnn_tutorial"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
